#!/usr/bin/env python3

# ==== PR merged
# - if label, do cherry-pick
# - if no label, add tracking issue to next release as "Needs testing"

# ==== PR labeled
# - if open, add PR to current release as "Fix (open)"
# - if closed, remove from next release, do cherry-pick

# ==== PR unlabeled
# - if open, remove from current release as "Fix (open)"
# - if closed, ???

# ==== cherry-pick conflict resolved
# - confirm cherry-pick was actually done
# - remove PR from current release as "Fix (conflict)"
# - add tracking issue to current release as "Needs testing"

# ==== cherry-pick (internal)
# - run Git to find the latest release branch and cherry-pick onto it
# - if success, push the branch, set PR in current release to "Fix (unreleased)"
# - if fail, notify and set PR in current release to "Fix (conflict)"

import argparse
import base64
import os
import re
import subprocess
import sys
from typing import Callable, Optional

import requests

import gql

# Test mode: enabled only when RELEASE_TEST is exactly "1". It selects the test
# issues repository below and the "TEST "-prefixed project titles.
TEST = os.environ.get("RELEASE_TEST") == "1"


# GitHub organization used as the owner for project searches and for the
# issues repository lookup.
ORG = "determined-ai"
# Name of the Git remote used for ls-remote, fetch, and push.
CLONED_REMOTE = "origin"
# Repository in which tracking issues are created.
ISSUES_REPO = "release-party-issues-test" if TEST else "release-party-issues"

# Only this PR label triggers cherry-pick handling; other labels are ignored.
CHERRY_PICK_LABEL = "to-cherry-pick"

# Option names of a project's status field; they are matched by name to find
# the corresponding option ids.
NEEDS_TESTING_STATUS = "Needs testing"
FIX_OPEN_STATUS = "Fix (open)"
FIX_CONFLICT_STATUS = "Fix (conflict)"
FIX_UNRELEASED_STATUS = "Fix (unreleased)"
FIX_RELEASED_STATUS = "Fix (released)"

# GITHUB_TOKEN is mandatory: a missing variable raises KeyError at import time.
# NOTE(review): not referenced in the visible part of this file; presumably
# required by the gql module -- confirm.
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
# Sent as the X-Casper-Token header with conflict notifications; defaults to "".
CASPER_TOKEN = os.environ.get("CASPER_TOKEN", "")


def run(*args, check=True, quiet=False, **kwargs):
    """Execute a command with ``subprocess.run`` and return its result.

    The positional arguments form the argument vector. Unless ``quiet`` is
    truthy, the command is echoed before it starts. ``check`` and any remaining
    keyword arguments are forwarded to ``subprocess.run``, so by default a
    non-zero exit status raises ``subprocess.CalledProcessError``.
    """
    if not quiet:
        banner = f"\n================ running: \x1b[36m{args}\x1b[m"
        print(banner)
    options = {"check": check, **kwargs}
    return subprocess.run(args, **options)


def run_capture(*args, **kwargs):
    """Run a command through ``run`` with stdout and stderr captured as text.

    Extra keyword arguments (for example ``check`` or ``quiet``) are passed on
    to ``run``. The captured streams are available as ``.stdout`` and
    ``.stderr`` on the returned result.
    """
    capture_options = {
        "stdout": subprocess.PIPE,
        "stderr": subprocess.PIPE,
        "universal_newlines": True,
    }
    return run(*args, **capture_options, **kwargs)


def make_issue_for_pr(issue_repo_id: str, pr_id: str) -> str:
    """Create a tracking issue for a PR and return the id of the new issue.

    The issue title names the PR's repository, number, and title; the body
    links back to the PR and repeats the PR's own body below a separator.
    """
    info = gql.get_pr_info(id=pr_id)["node"]
    repo_name = info["repository"]["name"]
    number = info["number"]
    original_title = info["title"]
    original_body = info["body"]
    url = info["url"]

    title = f"Test {repo_name}#{number} ({original_title})"
    print(f"Creating tracking issue '{title}'")

    body = f"(copied from {url})\n\n----\n\n{original_body}"
    created = gql.create_issue(repo=issue_repo_id, title=title, body=body)
    return created["createIssue"]["issue"]["id"]


def get_project_status_ids(project_id: str, status: str):
    """Return ``(field_id, option_id)`` for a named option of a project's status field.

    Raises ``StopIteration`` when the field has no option called ``status``.
    """
    field = gql.get_status_field_info(project=project_id)["node"]["field"]
    field_id = field["id"]
    matching_option_ids = (
        option["id"] for option in field["options"] if option["name"] == status
    )
    return field_id, next(matching_option_ids)


def add_item_to_project(project_id: str, item_id: str, status: str) -> None:
    """Add a content node (issue or PR) to a project and set its status.

    ``item_id`` is the id of the content being added; the id of the resulting
    project item is taken from the add response and used for the status update.
    """
    field_id, option_id = get_project_status_ids(project_id, status)
    added = gql.add_item_to_project(project=project_id, item=item_id)
    project_item_id = added["addProjectV2ItemById"]["item"]["id"]
    gql.set_project_item_status(
        project=project_id, item=project_item_id, field=field_id, value=option_id
    )


def set_project_item_status(project_id: str, item_id: str, status: str) -> None:
    """Set the status of an existing project item, identified by its item id."""
    field_id, option_id = get_project_status_ids(project_id, status)
    update = {
        "project": project_id,
        "item": item_id,
        "field": field_id,
        "value": option_id,
    }
    gql.set_project_item_status(**update)


def set_project_pr_status(project_id: str, pr_id: str, status: str) -> None:
    """Set the status of the project item that wraps the given PR.

    NOTE(review): ``project_item_id_for_pr`` returns ``None`` when the PR is not
    in the project, and that ``None`` is forwarded as the item id unchanged.
    """
    field_id, option_id = get_project_status_ids(project_id, status)
    gql.set_project_item_status(
        project=project_id,
        item=project_item_id_for_pr(project_id, pr_id),
        field=field_id,
        value=option_id,
    )


def add_tracking_issue_to_project(project_id: str, pr_id: str, status: str) -> None:
    """Create a tracking issue for a PR and add that issue to a project with a status."""
    issues_repo = gql.get_repo_id(owner=ORG, name=ISSUES_REPO)["repository"]
    tracking_issue_id = make_issue_for_pr(issues_repo["id"], pr_id)
    add_item_to_project(project_id, tracking_issue_id, status)


def find_project(owner: str, query: str, filt: Callable[[dict], bool]) -> dict:
    """Return the first project from a search that satisfies ``filt``.

    Raises ``StopIteration`` when none of the returned projects is accepted.
    """
    response = gql.search_projects(owner=owner, q=query)
    candidates = response["organization"]["projectsV2"]["nodes"]
    return next(filter(filt, candidates))


def next_project_id() -> str:
    """Return the id of the next-release project (the "TEST"-prefixed one in test mode)."""
    wanted_title = "TEST Next release" if TEST else "Next release"

    def has_wanted_title(project: dict) -> bool:
        # The title must match exactly.
        return project["title"] == wanted_title

    project = find_project(ORG, "Next release", has_wanted_title)
    return project["id"]


def current_project_id() -> str:
    """Return the id of the current-release project (the "TEST"-prefixed one in test mode)."""
    title_prefix = "TEST Current release" if TEST else "Current release"

    def has_wanted_prefix(project: dict) -> bool:
        # Only the start of the title is compared; any suffix is accepted.
        return project["title"].startswith(title_prefix)

    project = find_project(ORG, "Current release", has_wanted_prefix)
    return project["id"]


def project_item_id_for_pr(project_id: str, pr_id: str):
    """Return the id of the project item whose content is the given PR.

    Pages through the project's items until a match is found; returns ``None``
    if every page has been read without finding the PR.
    """
    cursor = None
    while True:
        page = gql.list_project_prs(project=project_id, after=cursor)["node"]["items"]
        for entry in page["nodes"]:
            content = entry["content"]
            # Entries with empty content cannot be the PR we are looking for.
            if content and content["id"] == pr_id:
                return entry["id"]
        page_info = page["pageInfo"]
        if not page_info["hasNextPage"]:
            return None
        cursor = page_info["endCursor"]


def cherry_pick_skipping_empty(commit):
    """Cherry-pick ``commit`` with ``-x``, skipping it if the pick turns out empty.

    When Git reports that the cherry-pick is now empty, the pick is skipped and
    the function returns normally. Any other failure prints the captured output
    and raises ``subprocess.CalledProcessError``.
    """
    result = run_capture("git", "cherry-pick", "-x", commit, check=False)
    if result.returncode == 0:
        return

    if "The previous cherry-pick is now empty" in result.stderr:
        run("git", "cherry-pick", "--skip")
        return

    print(result.stdout)
    print(result.stderr)
    # The exit status is known to be non-zero here, so this always raises.
    result.check_returncode()


def _latest_release_branch(ls_remote_output: str) -> str:
    """Pick the highest-versioned release branch out of ``git ls-remote`` output.

    Only refs ending in ``/release-X.Y.Z`` (three integer components) are
    considered; any other ref matched by the ``release-*`` glob is skipped
    instead of crashing the version comparison. Versions are compared
    numerically, component by component.

    Raises:
        ValueError: if the output names no versioned release branch.
    """
    branch_pat = re.compile(r"/release-(\d+)\.(\d+)\.(\d+)$")
    candidates = []
    for line in ls_remote_output.splitlines():
        fields = line.split()
        if len(fields) < 2:
            continue
        ref = fields[1]
        match = branch_pat.search(ref)
        if match is None:
            print(f"Ignoring non-versioned release ref {ref}")
            continue
        candidates.append(([int(part) for part in match.groups()], ref))

    if not candidates:
        raise ValueError("no release-X.Y.Z branch found on the remote")

    _, newest_ref = max(candidates, key=lambda candidate: candidate[0])
    return newest_ref[len("refs/heads/") :]


def cherry_pick_pr(pr_id: str) -> None:
    """Cherry-pick a PR's merge commit onto the latest release branch and push it.

    On success the PR's item in the current release project is set to
    "Fix (unreleased)". If any Git command fails, the traceback is printed, the
    item is set to "Fix (conflict)", and a conflict notification carrying the
    PR URL and the ``LOGS_URL`` environment value is posted.

    Args:
        pr_id: node id of the PR; its merge commit is what gets cherry-picked.

    Raises:
        ValueError: if the remote has no ``release-X.Y.Z`` branch.
    """
    pr = gql.get_pr_merge_commit_and_url(id=pr_id)["node"]
    pr_commit = pr["mergeCommit"]["oid"]
    print(f"Cherry-picking {pr_commit}")

    try:
        # Find the newest release branch, then fetch it along with the PR commit.
        release_branch = _latest_release_branch(
            run_capture("git", "ls-remote", CLONED_REMOTE, "refs/heads/release-*").stdout
        )
        print(f"Found release branch {release_branch}")

        run(
            "git",
            "fetch",
            "--depth=2",
            CLONED_REMOTE,
            pr_commit,
            f"{release_branch}:{release_branch}",
        )

        # Perform the cherry-pick and push.
        run("git", "config", "user.email", "automation@determined.ai")
        run("git", "config", "user.name", "Determined CI")
        run("git", "checkout", release_branch)
        cherry_pick_skipping_empty(pr_commit)
        run("git", "push", CLONED_REMOTE, f"{release_branch}:{release_branch}")

        print("Cherry-pick succeeded, updating item status")
        set_project_pr_status(current_project_id(), pr_id, FIX_UNRELEASED_STATUS)
    except subprocess.CalledProcessError:
        import traceback

        traceback.print_exc()
        print("Cherry-pick failed, adding PR as conflicted")
        set_project_pr_status(current_project_id(), pr_id, FIX_CONFLICT_STATUS)
        # Without a timeout an unresponsive endpoint would block the job forever.
        response = requests.post(
            "https://casper.internal.infra.determined.ai/hubot/conflict",
            headers={"X-Casper-Token": CASPER_TOKEN},
            json={"url": pr["url"], "logs_url": os.environ.get("LOGS_URL")},
            timeout=30,
        )
        if not response.ok:
            print(f"Conflict notification failed with HTTP {response.status_code}")


class Actions:
    """Actions selectable from the command line; each static method is one action.

    ``main`` maps its first argument (dashes replaced by underscores) to a
    method of this class and passes the remaining arguments positionally.
    """

    @staticmethod
    def pr_merged(pr_id: str):
        """Handle a merged PR.

        A PR carrying the cherry-pick label is cherry-picked. Otherwise a
        tracking issue is added to the next release project as "Needs testing",
        unless the title has a ``type:`` prefix other than feat/fix.
        """
        pr_labels = gql.get_pr_labels(id=pr_id)["node"]["labels"]["nodes"]
        print("Labels of merged PR:", [label["name"] for label in pr_labels])
        if any(label["name"] == CHERRY_PICK_LABEL for label in pr_labels):
            print("Cherry-picking labeled merged PR")
            cherry_pick_pr(pr_id)
        else:
            title = gql.get_pr_title(id=pr_id)["node"]["title"]
            if re.match(r"(feat|fix)\S*:", title, re.IGNORECASE) is not None:
                print("Adding feat/fix PR")
            elif re.match(r"\S+:", title, re.IGNORECASE) is not None:
                print("Skipping non-feat/fix PR")
                return
            else:
                # No recognizable "type:" prefix at all; track it to be safe.
                print("Adding PR of unknown type")

            print("Adding merged PR to next release project")
            add_tracking_issue_to_project(next_project_id(), pr_id, NEEDS_TESTING_STATUS)

    @staticmethod
    def pr_labeled(pr_id: str, label: str):
        """Handle a label being added to a PR; only the cherry-pick label matters.

        An open PR is added to the current release project as "Fix (open)". A
        merged PR is added the same way and then cherry-picked. A closed PR is
        ignored.
        """
        if label != CHERRY_PICK_LABEL:
            return

        state = gql.get_pr_state(id=pr_id)["node"]["state"]
        if state == "OPEN":
            print("Adding labeled open PR to current release project")
            add_item_to_project(current_project_id(), pr_id, FIX_OPEN_STATUS)
        elif state == "MERGED":
            # TODO Maybe delete the tracking issue in the next release that was
            # created when this merged without a label.
            print("Cherry-picking labeled merged PR")
            add_item_to_project(current_project_id(), pr_id, FIX_OPEN_STATUS)
            cherry_pick_pr(pr_id)
        elif state == "CLOSED":
            print("Ignoring label addition to closed PR")

    @staticmethod
    def pr_unlabeled(pr_id: str, label: str):
        """Handle a label being removed from a PR; only the cherry-pick label matters.

        An open PR is removed from the current release project. Removal from a
        PR in any other state is ignored.
        """
        if label != CHERRY_PICK_LABEL:
            return

        state = gql.get_pr_state(id=pr_id)["node"]["state"]
        if state == "OPEN":
            print("Removing unlabeled open PR from current release project")
            project_id = current_project_id()
            # Deletion takes the id of the project item, not of the PR it wraps,
            # so resolve the item first.
            item_id = project_item_id_for_pr(project_id, pr_id)
            if item_id is None:
                print("PR is not in the current release project, nothing to remove")
                return
            gql.delete_project_item(project=project_id, item=item_id)
        else:
            print(f"Ignoring label removal from {state.lower()} PR")

    @staticmethod
    def cherry_pick_conflict_resolved(pr_id: str):
        """Mark a PR in the current release project as "Fix (unreleased)"."""
        # TODO Use Git to confirm the cherry-pick was done.
        project_id = current_project_id()
        set_project_pr_status(project_id, pr_id, FIX_UNRELEASED_STATUS)

    @staticmethod
    def release_unreleased_prs():
        """Release every "Fix (unreleased)" item of the current release project.

        For each such item that has content, a tracking issue is added to the
        same project as "Needs testing" and the item itself is moved to
        "Fix (released)".
        """
        after_cursor = None
        project = current_project_id()
        while True:
            items = gql.list_project_prs(project=project, after=after_cursor)["node"]["items"]
            print("batch:", len(items["nodes"]))
            for item in items["nodes"]:
                # An item with no status set has no field value to inspect.
                status_value = item["fieldValueByName"]
                if status_value is None or status_value["name"] != FIX_UNRELEASED_STATUS:
                    continue
                if not item["content"]:
                    continue
                print(project, item["content"])
                add_tracking_issue_to_project(project, item["content"]["id"], NEEDS_TESTING_STATUS)
                set_project_item_status(project, item["id"], FIX_RELEASED_STATUS)
            if not items["pageInfo"]["hasNextPage"]:
                break
            after_cursor = items["pageInfo"]["endCursor"]


def main(args):
    """Dispatch to the ``Actions`` method named by the first argument.

    Dashes in the action name are replaced by underscores, and the remaining
    arguments are passed to the action positionally. Returns 1 when no action
    is given, otherwise whatever the action returns. The first element is
    removed from ``args`` in place.
    """
    if not args:
        print("Must provide an action!")
        return 1

    action_name = args.pop(0).replace("-", "_")
    handler = getattr(Actions, action_name)
    return handler(*args)


if __name__ == "__main__":
    # Use sys.exit rather than the site-provided ``exit`` helper, which is meant
    # for interactive sessions and is absent when the site module is not loaded.
    sys.exit(main(sys.argv[1:]))
